{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Use DeepSurv from the repo\n",
    "import sys\n",
    "sys.path.append('../deepsurv')\n",
    "import deep_surv\n",
    "\n",
    "from deepsurv_logger import DeepSurvLogger, TensorboardLogger\n",
    "import utils\n",
    "import viz\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import lasagne\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Read in dataset\n",
    "First, I read in the dataset and print its first five rows to get a sense of what it looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Load the example training set and preview its first five rows\n",
    "train_dataset_fp = './example_data.csv'\n",
    "train_df = pd.read_csv(filepath_or_buffer=train_dataset_fp)\n",
    "train_df.head(n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Transform the dataset to \"DeepSurv\" format\n",
    "DeepSurv expects a dataset to be in the form:\n",
    "\n",
    "    {\n",
    "        'x': numpy array of float32\n",
    "        'e': numpy array of int32\n",
    "        't': numpy array of float32\n",
    "        'hr': (optional) numpy array of float32\n",
    "    }\n",
    "    \n",
    "The data is provided as a csv, which I read in as a pandas dataframe. I then convert that dataframe into the DeepSurv dataset format above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# event_col is the header in the df that represents the 'Event / Status' indicator\n",
    "# time_col is the header in the df that represents the event time\n",
    "def dataframe_to_deepsurv_ds(df, event_col='Event', time_col='Time'):\n",
    "    \"\"\"Convert a pandas DataFrame into the dataset dict passed to DeepSurv.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame holding the event column, the time column and the\n",
    "            covariates. Every column other than event_col and time_col\n",
    "            is used as a covariate.\n",
    "        event_col: header of the 'Event / Status' indicator column.\n",
    "        time_col: header of the event time column.\n",
    "\n",
    "    Returns:\n",
    "        dict with 'x' (covariates, float32), 'e' (event indicator, int32)\n",
    "        and 't' (event time, float32) as numpy arrays.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if event_col or time_col is not a column of df.\n",
    "        ValueError: if event_col contains missing values. Casting NaN to\n",
    "            int32 does not fail; it produces an arbitrary integer that\n",
    "            would silently corrupt the event indicator.\n",
    "    \"\"\"\n",
    "    events = df[event_col]\n",
    "    if events.isnull().any():\n",
    "        raise ValueError(\n",
    "            \"Column '{}' contains missing values; drop or fill them \"\n",
    "            \"before converting.\".format(event_col))\n",
    "\n",
    "    # Whatever is left after removing the event and time columns is a covariate\n",
    "    covariates = df.drop([event_col, time_col], axis=1)\n",
    "\n",
    "    return {\n",
    "        'x': covariates.values.astype(np.float32),\n",
    "        'e': events.values.astype(np.int32),\n",
    "        't': df[time_col].values.astype(np.float32),\n",
    "    }\n",
    "\n",
    "# Pass the names of your csv's event and time headers through\n",
    "# 'event_col' and 'time_col' if they differ from the ones used here.\n",
    "# The same function also converts your validation and testing datasets.\n",
    "train_data = dataframe_to_deepsurv_ds(\n",
    "    train_df,\n",
    "    event_col='Event',\n",
    "    time_col='Time',\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once your dataset is formatted, define your hyper-parameters as a Python dictionary. \n",
    "I'll provide you with some example hyper-parameters, but you should replace the values once you have tuned them to your specific dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Example settings only: replace them once tuned for your own dataset\n",
    "hyperparams = dict(\n",
    "    L2_reg=10.0,\n",
    "    batch_norm=True,\n",
    "    dropout=0.4,\n",
    "    hidden_layers_sizes=[25, 25],\n",
    "    learning_rate=1e-05,\n",
    "    lr_decay=0.001,\n",
    "    momentum=0.9,\n",
    "    # Second dimension of the covariate array built earlier\n",
    "    n_in=train_data['x'].shape[1],\n",
    "    standardize=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you have prepared your dataset and defined your hyper-parameters, it's time to train DeepSurv!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the DeepSurv model from the hyperparams dictionary defined above\n",
    "model = deep_surv.DeepSurv(**hyperparams)\n",
    "\n",
    "# Optional: monitor training and validation with TensorBoard.\n",
    "# To train without the tensorboard logger, replace the three lines\n",
    "# below with:\n",
    "# logger = None\n",
    "experiment_name = 'test_experiment_sebastian'\n",
    "logdir = './logs/tensorboard/'\n",
    "logger = TensorboardLogger(experiment_name, logdir=logdir)\n",
    "\n",
    "# Training settings\n",
    "n_epochs = 2000\n",
    "# The type of optimizer to use. For other optimizers, check out\n",
    "# http://lasagne.readthedocs.io/en/latest/modules/updates.html\n",
    "update_fn = lasagne.updates.nesterov_momentum\n",
    "\n",
    "# Validation data, if you have any, can be passed as the second argument\n",
    "metrics = model.train(\n",
    "    train_data,\n",
    "    n_epochs=n_epochs,\n",
    "    logger=logger,\n",
    "    update_fn=update_fn,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "There are two different ways to visualize how the model trained:\n",
    "\n",
    "- Tensorboard (requires `tensorboard` to be installed), which provides realtime metrics. Run this command in a shell:\n",
    "   \n",
    "   `tensorboard --logdir './logs/tensorboard'`\n",
    "     \n",
    "     \n",
    "- Visualize the training functions post training (below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Report the last recorded training C-index\n",
    "final_train_c_index = metrics['c-index'][-1]\n",
    "print('Train C-Index:', final_train_c_index)\n",
    "# print('Valid C-Index: ',metrics['valid_c-index'][-1])\n",
    "\n",
    "# Draw the training / validation curves\n",
    "viz.plot_log(metrics)"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python [deep_surv]",
   "language": "python",
   "name": "Python [deep_surv]"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
